package com.rtsapp.agent.handler;

import com.rtsapp.agent.rules.AgentUtils;
import com.rtsapp.agent.rules.ClassNameCodex;
import com.rtsapp.server.logger.Logger;
import com.rtsapp.server.logger.LoggerFactory;
import com.rtsapp.server.utils.StringUtils;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.instrument.ClassDefinition;
import java.lang.instrument.Instrumentation;
import java.lang.instrument.UnmodifiableClassException;
import java.net.URISyntaxException;
import java.security.CodeSource;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

/**
 * Created by liwang on 17-3-28.
 */
class HotfixExecutor {

    private static final Logger LOGGER = LoggerFactory.getLogger( HotfixExecutor.class );

    // Directory scanned for hot-fix files
    public static final String ROOT_PATH = "hotfiles";
    // Directory that jars are moved to once they have been processed
    public static final String FIXED_PATH = ROOT_PATH + "/fixed";
    // Directory that files of unsupported types are moved to
    public static final String USELESS_PATH = ROOT_PATH + "/useless";

    // Suffix of class-file entries inside a jar
    private static final String CLASS_SUFFIX = ".class";


    private final Instrumentation inst;

    // Per-jar statistics; init() is called before each jar is processed
    private final FixResultLog fixLog = new FixResultLog( );

    // True while a hot fix is running; compareAndSet guarantees that runs never overlap
    private final AtomicBoolean runningTag = new AtomicBoolean( false );

    // Bytecode of every class that was redefined successfully, keyed by jar entry name.
    // When comparing, the "old" bytes are taken from here first and only fall back to the
    // code-source jar when absent. Only accessed while runningTag is held, so a plain HashMap
    // is used.
    private final Map<String, byte[]> fixedClassBytesMap = new HashMap<>( );


    /**
     * @param inst instrumentation instance used to check and redefine classes
     */
    public HotfixExecutor( Instrumentation inst ) {
        this.inst = inst;
    }


    /**
     * Processes every file under {@link #ROOT_PATH}. Returns immediately, doing nothing,
     * if another thread is already running a hot fix.
     *
     * @param rules class-entry filters (see {@link #testClassName}); null/empty accepts every class
     */
    public void startHotFix( String[] rules ) {

        // Bail out if another thread is already executing a hot fix
        if ( !runningTag.compareAndSet( false, true ) ) {
            return;
        }

        try {
            doHotFix( rules );
        } finally {
            // Always release the flag so the next hot fix can run
            runningTag.set( false );
        }
    }

    /**
     * Handles each file under {@link #ROOT_PATH}. Jars are applied and then moved to
     * {@link #FIXED_PATH}; any other file type is moved to {@link #USELESS_PATH}.
     */
    private void doHotFix( String[] rules ) {

        LOGGER.info( "rts-hotfix:  开始执行 热修复" );

        List<File> fileList = FileUtil.readFile( ROOT_PATH );
        int size = fileList == null ? 0 : fileList.size( );
        if ( size <= 0 ) {
            LOGGER.info( "rts-hotfix:  hotfiles下 没有需要处理的文件!!!" );
            return;
        }

        LOGGER.debug( "rts-hotfix:  hotfiles下有 {} 个文件需要处理 !", size );

        for ( File file : fileList ) {
            String fileName = file.getName( );

            if ( fileName.endsWith( ".jar" ) ) {
                dealWithJar( file, rules );

                // Archive the processed jar so it is not picked up again
                FileUtil.moveFile( file, FIXED_PATH + "/" + fileName );
            } else {
                LOGGER.error( "rts-hotfix: 不支持的文件类型!" );
                FileUtil.moveFile( file, USELESS_PATH + "/" + fileName );
            }
        }
    }


    /**
     * Redefines, in the running JVM, every class in the given jar whose bytecode differs from
     * the version currently in effect, then logs a summary.
     * <p>
     * Both jars are closed before returning, so the caller can safely move {@code file}
     * afterwards.
     *
     * @param file  the hot-fix jar
     * @param rules class-entry filters (see {@link #testClassName})
     */
    private void dealWithJar( File file, String[] rules ) {
        LOGGER.info( "rts-hotfix:  开始处理 jar 文件: {}", file.getName( ) );
        fixLog.init( );

        // try-with-resources skips null resources, so a failed open does not break closing the other jar
        try ( JarFile oldJarFile = getJarFileByPath( getCodeSourcePath( ) );
              JarFile jarFile = getJarFileByPath( file.getPath( ) ) ) {

            if ( oldJarFile == null || jarFile == null ) {
                return;
            }

            redefineChangedClasses( jarFile, oldJarFile, rules );

            LOGGER.info( fixLog.getResultLogStr( ) );
        } catch ( IOException e ) {
            // Only JarFile.close() can throw an IOException here
            LOGGER.error( e, "rts-hotfix:  关闭jar文件出错! jarName: {}", file.getName( ) );
        }
    }

    /**
     * Walks the class entries of {@code jarFile}. Each entry that passes the rules, maps to a
     * loaded modifiable class, and has changed bytecode is redefined. Results are recorded in
     * {@link #fixLog}.
     */
    private void redefineChangedClasses( JarFile jarFile, JarFile oldJarFile, String[] rules ) {

        fixLog.setFileCount( jarFile.size( ) );

        Enumeration<JarEntry> jarEntries = jarFile.entries( );
        while ( jarEntries.hasMoreElements( ) ) {

            JarEntry jarEntry = jarEntries.nextElement( );
            String jeName = jarEntry.getName( );

            // Only .class entries are candidates
            if ( !jeName.endsWith( CLASS_SUFFIX ) ) {
                continue;
            }

            fixLog.classCountIncrement( );

            // Apply the caller-supplied name filters
            if ( !testClassName( jeName, rules ) ) {
                continue;
            }

            Class<?> c = getModifiableClass( getPackageName( jeName ) );
            if ( c == null ) {
                continue;
            }

            byte[] newBytes = getBytesFromJarEntry( jarEntry, jarFile );
            if ( AgentUtils.isEmpty( newBytes ) ) {
                continue;
            }

            // Skip classes whose bytecode is identical to the version in effect
            if ( !isClassChanged( newBytes, jeName, oldJarFile ) ) {
                continue;
            }

            fixLog.needRedefineCountIncrement( );

            boolean redefineSuccess = redefineClass( c, newBytes );

            // Remember the applied bytes so later jars are compared against them
            if ( redefineSuccess ) {
                fixedClassBytesMap.put( jeName, newBytes );
            }

            fixLog.addRedefineClassName( jeName, redefineSuccess );
        }
    }

    /**
     * Tests a jar entry name against the rules.
     * <ul>
     *     <li>A rule ending in ".class" must equal the entry name exactly.</li>
     *     <li>A rule ending in "*" (longer than just "*") matches entries starting with the
     *     text before the "*".</li>
     * </ul>
     * Null/empty rules are ignored. A null/empty rule array accepts everything.
     *
     * @param className jar entry name, e.g. {@code com/foo/Bar.class}
     * @param rules     rules to test against
     * @return true if any rule matches, or if there are no rules
     */
    private boolean testClassName( String className, String[] rules ) {
        if ( AgentUtils.isEmpty( rules ) ) {
            return true;
        }

        for ( String rule : rules ) {
            if ( AgentUtils.isEmpty( rule ) ) {
                continue;
            }

            if ( rule.endsWith( CLASS_SUFFIX ) ) {
                if ( className.equals( rule ) ) {
                    return true;
                }
            } else if ( rule.endsWith( "*" ) && rule.length( ) > 1
                    && className.startsWith( rule.substring( 0, rule.length( ) - 1 ) ) ) {
                return true;
            }
        }

        return false;
    }

    /**
     * Returns the file-system path of the jar/directory this class was loaded from.
     * The location URL is decoded (e.g. {@code %20} becomes a space) so that paths containing
     * spaces or non-ASCII characters can be opened.
     *
     * @return the path, or null when no code source is available
     */
    private String getCodeSourcePath( ) {
        CodeSource codeSource = getClass( ).getProtectionDomain( ).getCodeSource( );
        if ( codeSource == null || codeSource.getLocation( ) == null ) {
            return null;
        }

        try {
            return new File( codeSource.getLocation( ).toURI( ) ).getPath( );
        } catch ( URISyntaxException | IllegalArgumentException ignored ) {
            // Not a well-formed file: URI; fall back to the raw (undecoded) URL path
            return codeSource.getLocation( ).getPath( );
        }
    }

    /**
     * Opens the jar at the given path.
     *
     * @param jarPath path of the jar file
     * @return the opened jar (the caller must close it), or null when the path is empty or the
     * jar cannot be opened (the error is logged)
     */
    private JarFile getJarFileByPath( String jarPath ) {
        if ( StringUtils.isEmpty( jarPath ) ) {
            LOGGER.error( "rts-hotfix:  jarPath 不能为空!" );
            return null;
        }

        try {
            return new JarFile( jarPath );
        } catch ( IOException e ) {
            LOGGER.error( e, "rts-hotfix:  jarPath有误! jarPath: {}", jarPath );
        }

        return null;
    }

    /**
     * Looks up a class by binary name, using Class.forName(String), which loads and initializes
     * the class through this class's defining loader.
     *
     * @param className binary class name, e.g. {@code com.foo.Bar$Inner}
     * @return the class, or null if it cannot be found/loaded or does not support redefinition
     */
    private Class<?> getModifiableClass( String className ) {

        Class<?> c;

        try {
            c = Class.forName( className );
        } catch ( NullPointerException e ) {
            LOGGER.error( e, "rts-hotfix:  Class.forName(): 类名不能为空!" );
            return null;
        } catch ( ClassNotFoundException e ) {
            LOGGER.error( e, "rts-hotfix:  此类未加载! className: {}", className );
            return null;
        } catch ( NoClassDefFoundError e ) {
            LOGGER.error( e, "rts-hotfix:  className: {}", className );
            return null;
        } catch ( Throwable e ) {
            LOGGER.error( e, "rts-hotfix:  Class.forName( className ) :    Throwable !!  className: {}", className );
            return null;
        }

        return inst.isModifiableClass( c ) ? c : null;
    }

    /**
     * Converts a jar entry name to a binary class name,
     * e.g. {@code com/foo/Bar$Inner.class} becomes {@code com.foo.Bar$Inner}.
     * <p>
     * Jar entry names always use '/' as separator, independent of the host OS.
     *
     * @param path jar entry name
     * @return the binary class name
     */
    private String getPackageName( String path ) {
        String name = path.endsWith( CLASS_SUFFIX )
                ? path.substring( 0, path.length( ) - CLASS_SUFFIX.length( ) )
                : path;
        return name.replace( '/', '.' );
    }

    /**
     * Reads the full contents of a jar entry.
     *
     * @param jarEntry entry to read
     * @param jarFile  jar containing the entry
     * @return the entry bytes, or null if the entry cannot be read (partial data is never returned)
     */
    private byte[] getBytesFromJarEntry( JarEntry jarEntry, JarFile jarFile ) {
        try ( InputStream in = jarFile.getInputStream( jarEntry ) ) {
            if ( in == null ) {
                return null;
            }

            ByteArrayOutputStream out = new ByteArrayOutputStream( );
            byte[] buffer = new byte[ 1024 ];
            int len;
            while ( ( len = in.read( buffer ) ) != -1 ) {
                out.write( buffer, 0, len );
            }
            return out.toByteArray( );

        } catch ( IOException e ) {
            LOGGER.error( e, "rts-hotfix:  jarEntity转bytes出错! jarName:{} , className: {}", jarFile.getName( ), jarEntry.getName( ) );
            return null;
        }
    }

    /**
     * Checks whether a class's bytecode differs from the version currently in effect. The
     * current version is either the last successfully applied fix, or otherwise the same
     * entry in the code-source jar.
     *
     * @param newClassBytes bytecode from the hot-fix jar
     * @param newClassName  jar entry name of the class
     * @param oldJarFile    jar this class was loaded from
     * @return true if the bytes differ; false if they are equal or the old bytes cannot be found/read
     */
    private boolean isClassChanged( byte[] newClassBytes, String newClassName, JarFile oldJarFile ) {

        byte[] oldClassBytes = fixedClassBytesMap.get( newClassName );
        if ( AgentUtils.isEmpty( oldClassBytes ) ) {

            // Never fixed before: compare against the code-source jar
            JarEntry oldJarEntry = oldJarFile.getJarEntry( newClassName );
            if ( oldJarEntry == null ) {
                LOGGER.error( "rts-hotfix:  类型: {} 在原始jar包中未找到!", newClassName );
                return false;
            }

            oldClassBytes = getBytesFromJarEntry( oldJarEntry, oldJarFile );
            if ( AgentUtils.isEmpty( oldClassBytes ) ) {
                return false;
            }
        }

        return !Arrays.equals( newClassBytes, oldClassBytes );
    }


    /**
     * Calls {@link Instrumentation#redefineClasses}, logging any failure.
     *
     * @return true on success, false if any exception/error was thrown
     */
    private boolean redefineClasses( ClassDefinition... defs ) {
        try {
            inst.redefineClasses( defs );
            return true;
        } catch ( ClassNotFoundException e ) {
            LOGGER.error( e, "rts-hotfix:  未加载此类, 无需更新!" );
        } catch ( UnmodifiableClassException e ) {
            LOGGER.error( e, "rts-hotfix:  此类型不支持热更新!" );
        } catch ( Throwable e ) {
            LOGGER.error( e, "rts-hotfix:  redefineClasses 发生未知错误!" );
        }

        return false;
    }

    /**
     * Redefines a single class with the given bytecode.
     *
     * @return true on success
     */
    private boolean redefineClass( Class<?> c, byte[] bytes ) {
        LOGGER.info( "rts-hotfix:  开始 redefineClass, className: {}", c.getName( ) );
        return redefineClasses( new ClassDefinition( c, bytes ) );
    }


    /**
     * Collects statistics for processing a single jar and renders them as a log message.
     * <p>
     * Created by liwang on 17-3-21.
     */
    private static final class FixResultLog {

        // Total number of entries in the jar
        private int fileCount = 0;
        // Number of .class entries in the jar
        private int classCount = 0;
        // Number of .class entries whose bytecode changed and needed redefinition
        private int needRedefineCount = 0;
        // Entry names that were redefined successfully
        private final List<String> redefineSuccessList = new ArrayList<>( );
        // Entry names whose redefinition failed
        private final List<String> redefineFailList = new ArrayList<>( );


        /** Resets all counters and lists before processing a new jar. */
        public void init( ) {
            this.fileCount = this.classCount = this.needRedefineCount = 0;
            this.redefineSuccessList.clear( );
            this.redefineFailList.clear( );
        }

        public void setFileCount( int fileCount ) {
            this.fileCount = fileCount;
        }

        public void classCountIncrement( ) {
            this.classCount++;
        }

        public void needRedefineCountIncrement( ) {
            this.needRedefineCount++;
        }

        /** Records an entry name in the success or failure list. */
        public void addRedefineClassName( String className, boolean redefineSuccess ) {
            ( redefineSuccess ? redefineSuccessList : redefineFailList ).add( className );
        }

        private final static String LOG_LIGHT_LINE = "\n------------------------------------------------------------------------------------------------------------\n";
        private final static String LINE_HEAD = "\n        ";
        private final static String LINE_HEAD_2 = "\n                ";

        /** Renders the collected statistics as a multi-line summary. */
        public String getResultLogStr( ) {
            StringBuilder logSbd = new StringBuilder( LOG_LIGHT_LINE )
                    .append( LINE_HEAD ).append( "文件总数: " ).append( fileCount )
                    .append( LINE_HEAD ).append( ".class文件总数: " ).append( classCount )
                    .append( LINE_HEAD ).append( "需要fix的.class文件总数: " ).append( needRedefineCount )
                    .append( LINE_HEAD ).append( "成功执行fix的.class文件总数: " ).append( redefineSuccessList.size( ) )
                    .append( LINE_HEAD ).append( "fix成功 文件列表: " );
            for ( int i = 0; i < redefineSuccessList.size( ); i++ ) {
                logSbd.append( LINE_HEAD_2 ).append( i + 1 ).append( ". " ).append( redefineSuccessList.get( i ) );
            }

            logSbd.append( LINE_HEAD ).append( "fix失败 文件列表: " );
            for ( int i = 0; i < redefineFailList.size( ); i++ ) {
                logSbd.append( LINE_HEAD_2 ).append( i + 1 ).append( ". " ).append( redefineFailList.get( i ) );
            }

            logSbd.append( "\n" ).append( LOG_LIGHT_LINE );

            return logSbd.toString( );
        }
    }

}
